#!/usr/bin/env python3
"""
run_sweeps.py
==============

Execute the volume sweeps for the agent‑generated kernels and flip counts.
This script reads the experiment configuration from a YAML file, iterates
over all parameter combinations (``b``, ``k``, ``n0``, ``L``) and gauge
groups, computes the fundamental and adjoint string tensions, and writes
the results to a CSV file.

The computation uses the simplified formula

    σ_fund = b * mean(kernel ** k)

and applies the SU(N) Casimir scaling to obtain σ_adj.  Flip counts are
loaded but not directly used; they are required to maintain API
compatibility with future models.

Results are written to ``{output_dir}/adjoint_volume_summary.csv``.

Example usage::

    python run_sweeps.py --config configs/default.yaml --output-dir data/results/vol4_wilson_loop_adjoint_volume_sweep

"""

import argparse
import csv
import os
from itertools import product
from typing import Sequence, Dict, Any, List

import numpy as np
import yaml


def fundamental_string_tension(kernel: np.ndarray, b: float, k_exp: float) -> float:
    """Compute the fundamental string tension ``b * mean(kernel ** k_exp)``.

    Parameters
    ----------
    kernel : numpy.ndarray
        Kernel values.  Any array-like (e.g. a list) is accepted and is
        converted to a float array; the mean is taken over all elements.
        Values are expected to be non-negative but are *not* validated:
        negative entries combined with a non-integer ``k_exp`` yield NaN.
    b : float
        Coupling constant.
    k_exp : float
        Exponent applied to each kernel entry prior to averaging.

    Returns
    -------
    float
        The fundamental string tension σ_fund.

    Raises
    ------
    ValueError
        If ``kernel`` contains no elements (the mean would be undefined).
    """
    # asarray (rather than kernel.astype) lets plain sequences through too.
    values = np.asarray(kernel, dtype=float)
    if values.size == 0:
        # Fail loudly instead of propagating a silent NaN into the results.
        raise ValueError("kernel must contain at least one element")
    return float(b * np.mean(values ** k_exp))


def adjoint_string_tension(sigma_fund: float, gauge: str) -> float:
    """Compute the adjoint string tension using Casimir scaling.

    The number of colours ``N`` is read from the trailing digits of the
    gauge label and the fundamental tension is multiplied by
    ``2 * N**2 / (N**2 - 1)``.

    Parameters
    ----------
    sigma_fund : float
        The fundamental string tension.
    gauge : str
        Gauge label ending in the integer ``N`` (e.g. "SU2", "SU3", "SU10").

    Returns
    -------
    float
        The adjoint string tension σ_adj.

    Raises
    ------
    ValueError
        If the label does not end in digits, or if ``N < 2`` (the scaling
        factor is undefined for ``N = 1`` and meaningless for ``N = 0``).
    """
    label = gauge.strip()
    # Take the whole trailing run of digits so multi-digit groups such as
    # "SU10" parse as N=10 rather than as their last character only.
    digits = label[len(label.rstrip("0123456789")):]
    if not digits:
        raise ValueError(f"gauge label {gauge!r} does not end in an integer N")
    n_colors = int(digits)
    if n_colors < 2:
        raise ValueError(f"gauge label {gauge!r} must have N >= 2, got N={n_colors}")
    n_squared = n_colors ** 2
    factor = (2 * n_squared) / (n_squared - 1)
    return sigma_fund * factor


def load_kernel(cfg: Dict[str, Any], gauge: str, L: int) -> np.ndarray:
    """Load the kernel array for a given gauge group and lattice size.

    The path template is looked up under ``cfg["adjoint_volume"]["kernel_paths"]``
    when that key exists, otherwise under ``cfg["kernel_path_template"]``.
    The value found there may be either

    * a mapping from gauge label to a template string, or
    * a single template string shared by all gauge groups.

    The template is expanded with ``str.format`` using the keyword fields
    ``L`` and ``gauge`` (templates need not use both) and the resulting
    file is read with ``numpy.load``.

    Parameters
    ----------
    cfg : dict
        Parsed experiment configuration.
    gauge : str
        Gauge label used to select (and optionally fill in) the template.
    L : int
        Lattice size substituted into the template.

    Returns
    -------
    numpy.ndarray
        The array stored at the resolved path.

    Raises
    ------
    KeyError
        If neither configuration key is present, or a per-gauge mapping
        has no entry for ``gauge``.
    """
    # ``adjoint_volume:`` with no children parses to None in YAML, so fall
    # back to an empty mapping instead of testing membership on None.
    section = cfg.get("adjoint_volume") or {}
    if "kernel_paths" in section:
        templates = section["kernel_paths"]
    else:
        templates = cfg["kernel_path_template"]
    # A bare string is a template shared by every gauge group.
    template = templates if isinstance(templates, str) else templates[gauge]
    path = template.format(L=L, gauge=gauge)
    return np.load(path)


def load_flip_counts(cfg: Dict[str, Any], L: int) -> np.ndarray:
    """Read the flip-count array stored for lattice size ``L``.

    The file location comes from ``cfg["flip_counts_path_template"]``,
    expanded with ``str.format(L=L)``, and is read with ``numpy.load``.
    """
    return np.load(cfg["flip_counts_path_template"].format(L=L))


def main(argv: Sequence[str] | None = None) -> None:
    """Run the sweep described by a YAML config and write the summary CSV.

    Parameters
    ----------
    argv : sequence of str, optional
        Command-line arguments (``--config`` and ``--output-dir``).  When
        ``None``, argparse falls back to ``sys.argv[1:]``.

    Notes
    -----
    For every combination of ``b_values``, ``k_values``, ``n0_values`` and
    ``L_values`` and every entry in ``adjoint_volume.gauge_groups``, one row
    is written to ``<output-dir>/adjoint_volume_summary.csv`` and a progress
    line is printed.  Kernel files depend only on ``(gauge, L)`` and
    flip-count files only on ``L``, so each is read from disk once (on first
    use) rather than once per parameter combination.  Loaded kernels are
    kept in memory for the duration of the sweep.
    """
    parser = argparse.ArgumentParser(description="Run adjoint volume sweeps with agent‑generated kernels")
    parser.add_argument("--config", required=True, help="Path to YAML configuration file")
    parser.add_argument("--output-dir", required=True, help="Directory where results CSV will be stored")
    args = parser.parse_args(argv)

    with open(args.config, "r", encoding="utf-8") as f:
        # An empty YAML document parses to None; treat it as an empty config.
        cfg = yaml.safe_load(f) or {}

    out_dir = args.output_dir
    os.makedirs(out_dir, exist_ok=True)
    out_csv = os.path.join(out_dir, "adjoint_volume_summary.csv")

    # Collect parameter ranges (a key present with a null value counts as empty).
    b_values: List[float] = cfg.get("b_values") or []
    k_values: List[float] = cfg.get("k_values") or []
    n0_values: List[int] = cfg.get("n0_values") or []
    L_values: List[int] = cfg.get("L_values") or []
    gauge_groups: List[str] = (cfg.get("adjoint_volume") or {}).get("gauge_groups") or []

    # Kernels are reused across all (b, k, n0) combinations, so cache them.
    kernel_cache: Dict[Any, np.ndarray] = {}
    # Flip counts are unused by the current model; remember only which
    # lattice sizes have already been loaded so each file is read once.
    flip_counts_loaded = set()

    # Run through all combinations
    with open(out_csv, "w", newline="", encoding="utf-8") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=[
            "b", "k", "n0", "L", "gauge", "sigma_fund", "sigma_adj"
        ])
        writer.writeheader()
        for b, k_exp, n0, L in product(b_values, k_values, n0_values, L_values):
            for gauge in gauge_groups:
                # Load data lazily, in the same order as a fresh read would,
                # so a missing file still fails at its first use.
                cache_key = (gauge, L)
                if cache_key not in kernel_cache:
                    kernel_cache[cache_key] = load_kernel(cfg, gauge, L)
                kernel = kernel_cache[cache_key]
                if L not in flip_counts_loaded:
                    load_flip_counts(cfg, L)  # loaded for completeness but unused
                    flip_counts_loaded.add(L)
                # Compute tensions
                sigma_f = fundamental_string_tension(kernel, b, k_exp)
                sigma_a = adjoint_string_tension(sigma_f, gauge)
                writer.writerow({
                    "b": b,
                    "k": k_exp,
                    "n0": n0,
                    "L": L,
                    "gauge": gauge,
                    "sigma_fund": sigma_f,
                    "sigma_adj": sigma_a,
                })
                print(f"Computed b={b}, k={k_exp}, n0={n0}, L={L}, gauge={gauge} → sigma_fund={sigma_f:.4f}, sigma_adj={sigma_a:.4f}")

    print(f"Saved results to {out_csv}")


# Script entry point: run the sweep only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":  # pragma: no cover
    main()